# -*- coding: utf-8 -*-
"""
------------------------------------------------------------------------------
    File Name:  lstm_demo
    Author   :  wanwei1029
    Date     :  2018-10-12
    Desc     :
------------------------------------------------------------------------------
"""
from keras.models import Sequential
from keras.layers.recurrent import LSTM
import numpy as np


def demo():
    """
    LSTM参数说明：
    units：输出维度。
    input_dim：输入维度，当使用该层为模型首层时，应指定该值（或等价的指定input_shape)
    input_length：当输入序列的长度固定时，该参数为输入序列的长度。当需要在该层后连接Flatten层，然后又要连接Dense层时，
        需要指定该参数，否则全连接的输出无法计算出来。
    """
    model = Sequential()
    # model.add(LSTM(128, input_dim=64, input_length=5, return_sequences=True))
    model.add(LSTM(128, input_shape=(3, 2), return_sequences=True))
    model.add(LSTM(256, return_sequences=False))
    input_array = np.random.randint(10, size=(3, 2))
    model.compile('rmsprop', 'mse')
    output_array = model.predict(input_array)
    print(output_array)


if __name__ == '__main__':
    test_method = "demo"
    if test_method == "demo":
        demo()
